import os
import contextlib
import copy
import numpy as np

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from utils import all_gather

class CocoEvaluator():
    """Accumulate model predictions and evaluate them with pycocotools.

    One ``COCOeval`` instance is kept per requested IoU type. Predictions are
    converted to COCO result dicts and evaluated per image in :meth:`update`,
    merged across processes in :meth:`synchronize_between_processes`, and
    finally reduced with :meth:`accumulate` / :meth:`summarize`.

    Model labels are assumed to be contiguous indices ``0..K-1`` into
    ``coco_gt.getCatIds()``; for every IoU type they are mapped back to COCO
    category ids (via ``self.ids2cats``) before evaluation.
    """

    def __init__(self, coco_gt, iou_types):
        """
        Args:
            coco_gt: ground-truth ``COCO`` object. It is deep-copied, so the
                caller's object is not shared with the evaluators.
            iou_types: list or tuple of IoU types; each must be one of
                'bbox', 'segm' or 'keypoints'.
        """
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt
        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

        # Lookup tables between contiguous label indices and COCO category ids.
        cat_ids = self.coco_gt.getCatIds()
        self.ids2cats = dict(enumerate(cat_ids))
        self.cats2ids = {cat: idx for idx, cat in enumerate(cat_ids)}

    def update(self, predictions):
        """Evaluate one batch of predictions and store the per-image results.

        Args:
            predictions: dict mapping image id -> prediction dict. The keys
                read depend on the IoU type ('boxes', 'scores', 'labels',
                'masks', 'keypoints').
        """
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # Silence the messages pycocotools prints while loading results.
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            # Keep ``img_ids`` untouched so every IoU type sees the same ids.
            _, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        """Concatenate this process's batches and merge them across processes.

        Per-batch arrays are joined along axis 2 (the image axis produced by
        ``evaluate``). NOTE(review): intended to be called once, after all
        ``update`` calls; a second call would re-concatenate merged arrays.
        """
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type],
                                    self.img_ids,
                                    self.eval_imgs[iou_type])

    def accumulate(self):
        """Run ``COCOeval.accumulate`` for every IoU type."""
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        """Print the COCO summary for each IoU type.

        Returns:
            dict mapping IoU type -> ``coco_eval.stats``.
        """
        stats_dict = {}
        for iou_type, coco_eval in self.coco_eval.items():
            print(f'IoU metric: {iou_type}')
            coco_eval.summarize()
            stats_dict[iou_type] = coco_eval.stats
        return stats_dict

    def prepare(self, predictions, iou_type):
        """Convert ``predictions`` into COCO result dicts for ``iou_type``.

        Raises:
            ValueError: if ``iou_type`` is not 'bbox', 'segm' or 'keypoints'.
        """
        if iou_type == 'bbox':
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == 'segm':
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == 'keypoints':
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError(f'Unknown iou type {iou_type}')

    def _to_category_ids(self, labels):
        """Map contiguous model label indices to COCO category ids."""
        return [self.ids2cats[label] for label in labels]

    def prepare_for_coco_detection(self, predictions):
        """Convert box predictions into COCO result dicts (xywh boxes)."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue
            boxes = convert_to_xywh(prediction['boxes']).tolist()
            scores = prediction['scores'].tolist()
            labels = self._to_category_ids(prediction['labels'].tolist())

            coco_results.extend(
                [
                    {
                        'image_id': original_id,
                        'category_id': labels[k],
                        'bbox': box,
                        'score': scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        """Convert mask predictions into COCO result dicts with RLE masks."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue
            scores = prediction['scores'].tolist()
            labels = self._to_category_ids(prediction['labels'].tolist())
            # Binarize soft masks at 0.5.
            masks = prediction['masks'] > 0.5

            # NOTE(review): ``mask[0]`` assumes masks are (N, 1, H, W); confirm
            # against the model output. ``mask_util.encode`` expects a
            # Fortran-ordered uint8 (H, W, n) array, hence the new last axis.
            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order='F'))[0]
                for mask in masks
            ]
            for rle in rles:
                # pycocotools returns the counts as bytes; store them as str.
                rle['counts'] = rle['counts'].decode('utf-8')

            coco_results.extend(
                [
                    {
                        'image_id': original_id,
                        'category_id': labels[k],
                        'segmentation': rle,
                        'score': scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        """Convert keypoint predictions into COCO result dicts.

        Each instance's keypoints are flattened into a single list, e.g. a
        (K, 3) array becomes ``[x1, y1, v1, x2, y2, v2, ...]``.
        """
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue
            scores = prediction['scores'].tolist()
            labels = self._to_category_ids(prediction['labels'].tolist())
            # Flatten per instance with numpy (works for zero instances too);
            # the previous ``flatten(start_dim=1)`` only existed on torch tensors.
            keypoints = [kp.ravel().tolist() for kp in np.asarray(prediction['keypoints'])]

            coco_results.extend(
                [
                    {
                        'image_id': original_id,
                        'category_id': labels[k],
                        'keypoints': keypoint,
                        'score': scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    """Convert corner-form boxes to COCO ``[x, y, width, height]`` form.

    Args:
        boxes: 2-D array indexable as ``boxes[:, i]`` whose columns are
            ``(xmin, ymin, xmax, ymax)``.

    Returns:
        ``np.ndarray`` of shape ``(N, 4)`` with columns
        ``(xmin, ymin, width, height)``.
    """
    x0, y0, x1, y1 = (boxes[:, col] for col in range(4))
    width = x1 - x0
    height = y1 - y0
    return np.stack([x0, y0, width, height], axis=1)


def merge(img_ids, eval_imgs):
    """Gather per-process evaluation results and de-duplicate them by image id.

    Args:
        img_ids: image ids evaluated by this process.
        eval_imgs: per-image results whose last (axis 2) dimension is aligned
            with ``img_ids``.

    Returns:
        ``(unique_ids, merged_eval_imgs)``: sorted unique image ids, and the
        concatenated results restricted to the first occurrence of each id.
    """
    # Both collectives are issued in the same order on every process.
    gathered_ids = all_gather(img_ids)
    gathered_evals = all_gather(eval_imgs)

    flat_ids = np.array([img_id for chunk in gathered_ids for img_id in chunk])
    stacked_evals = np.concatenate(list(gathered_evals), 2)

    # return_index gives the first position of each sorted unique id, so an
    # image seen by several processes is kept only once.
    unique_ids, first_idx = np.unique(flat_ids, return_index=True)
    return unique_ids, stacked_evals[..., first_idx]


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    """Install results merged across processes into ``coco_eval``.

    Sets ``evalImgs``, ``params.imgIds`` and a snapshot ``_paramsEval`` so a
    subsequent ``coco_eval.accumulate()`` works on the merged images.
    """
    merged_ids, merged_evals = merge(img_ids, eval_imgs)
    params = coco_eval.params

    coco_eval.evalImgs = list(merged_evals.flatten())
    params.imgIds = list(merged_ids)
    coco_eval._paramsEval = copy.deepcopy(params)


#################################################################
# Adapted from pycocotools COCOeval.evaluate: progress prints are
# removed, and the per-image results are returned as a reshaped
# array instead of being stored on self.evalImgs
#################################################################


def evaluate(self):
    '''
    Run per-image evaluation on the images in ``self.params.imgIds``.

    Adapted from ``pycocotools.cocoeval.COCOeval.evaluate``. Unlike the
    upstream method, the per-image results are not stored in
    ``self.evalImgs``; they are returned instead.

    :param self: a ``COCOeval`` instance whose ``cocoDt`` and ``params`` are set.
    :return: tuple ``(imgIds, evalImgs)``. ``imgIds`` is the sorted,
        de-duplicated list of evaluated image ids. ``evalImgs`` is an array of
        ``self.evaluateImg`` results with shape
        ``(len(catIds), len(areaRng), len(imgIds))``.
    :raises ValueError: if ``params.iouType`` is not 'segm', 'bbox' or 'keypoints'.
    '''
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    # Pick the similarity function before doing any work so an unsupported
    # iouType fails with a clear error instead of an unbound local name.
    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    else:
        raise ValueError(f'Unsupported iouType {p.iouType!r}')

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code: the flat list is reshaped so the
    # image axis (axis 2) can be concatenated across batches/processes.
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    return p.imgIds, evalImgs

#################################################################
# end of code adapted from pycocotools
#################################################################
